<div class="doc-body">
    <section class="hero">
        <img src="./_static/maxtext.png" class="hero-image">
        <section class="hero-text">
            <div>
                <span><h1>MaxText</h1></span>
                <h3>A high-performance, highly scalable, open-source LLM library and reference implementation written in pure Python/JAX, targeting Google Cloud TPUs and GPUs for training.</h3>
                <div class="hero-cta">
                    <a class="button button-primary" href="./tutorials.html">Get started</a>
                </div>
            </div>
        </section>
    </section>
    <section class="three-up">
        <div>
            <h3>High-performance</h3>
            <p>MaxText achieves high Model FLOPs Utilization (MFU) and tokens per second, from a single host to very large clusters, while staying simple and largely "optimization-free" thanks to JAX and the XLA compiler.</p>
        </div>
        <div>
            <h3>Pre-training</h3>
            <p>MaxText provides opinionated implementations of how to achieve high performance across a wide range of dimensions, such as sharding, quantization, and checkpointing.</p>
        </div>
        <div>
            <h3>Post-training</h3>
            <p>MaxText provides a scalable framework for post-training proprietary or open-source models, using state-of-the-art Reinforcement Learning (RL) algorithms (e.g., GRPO) as well as techniques such as Supervised Fine-Tuning (SFT) and Knowledge Distillation.</p>
        </div>
    </section>
    <section class="banner">
        <h3>JAX AI Stack</h3>
        <p>The JAX AI Stack is a curated collection of libraries that researchers and engineers, both inside and outside of Google, have found useful for implementing and deploying the models behind generative AI tools such as Imagen and Gemini.</p>
        <ul>
            <li><a href="http://jax.dev/">JAX</a> - Core array operations and program transformations</li>
            <li><a href="https://flax.readthedocs.io/en/latest/">Flax</a> - Building neural networks</li>
            <li><a href="https://orbax.readthedocs.io/en/latest/">Orbax</a> - Checkpointing and persistence utilities</li>
            <li><a href="https://optax.readthedocs.io/en/latest/">Optax</a> - Gradient processing and optimization</li>
            <li><a href="https://tunix.readthedocs.io/en/latest/">Tunix</a> - A JAX library with the latest experimental algorithms and post-training techniques</li>
            <li><a href="https://github.com/jax-ml/ml_dtypes">ml_dtypes</a> - NumPy dtype extensions for machine learning</li>
            <li><a href="https://maxtext.readthedocs.io/en/latest/index.html#model-library">MaxText model library</a> - JAX LLMs highly optimized for TPUs</li>
            <li><a href="https://blog.vllm.ai/2025/10/16/vllm-tpu.html">vLLM on TPU</a> - High-performance sampling (inference) for Reinforcement Learning (RL)</li>
            <li><a href="https://docs.cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro">Pathways</a> - Multi-host inference (sampling) and highly efficient weight transfer</li>
            <li>Optional data loading libraries (<a href="https://google-grain.readthedocs.io/en/latest/">Grain</a> or <a href="https://www.tensorflow.org/guide/data">tf.data</a>)</li>
        </ul>
    </section>
</div>
